{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "No module named 'matplotlib'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-9e3324102725>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'matplotlib'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'inline'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.5/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m   2129\u001b[0m                 \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'local_ns'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getframe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf_locals\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2130\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2131\u001b[0;31m                 \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2132\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m</home/jinbo/anaconda3/lib/python3.5/site-packages/decorator.py:decorator-gen-107>\u001b[0m in \u001b[0;36mmatplotlib\u001b[0;34m(self, line)\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.5/site-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m    185\u001b[0m     \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m         \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.5/site-packages/IPython/core/magics/pylab.py\u001b[0m in \u001b[0;36mmatplotlib\u001b[0;34m(self, line)\u001b[0m\n\u001b[1;32m     97\u001b[0m             \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Available matplotlib backends: %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mbackends_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     98\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m             \u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_matplotlib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    100\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_show_matplotlib_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.5/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36menable_matplotlib\u001b[0;34m(self, gui)\u001b[0m\n\u001b[1;32m   3035\u001b[0m         \"\"\"\n\u001b[1;32m   3036\u001b[0m         \u001b[0;32mfrom\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcore\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpylabtools\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3037\u001b[0;31m         \u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfind_gui_and_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpylab_gui_select\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3038\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3039\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mgui\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'inline'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.5/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36mfind_gui_and_backend\u001b[0;34m(gui, gui_select)\u001b[0m\n\u001b[1;32m    271\u001b[0m     \"\"\"\n\u001b[1;32m    272\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 273\u001b[0;31m     \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    274\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    275\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mgui\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mgui\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'auto'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mImportError\u001b[0m: No module named 'matplotlib'"
     ]
    }
   ],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "训练一个分类器\n",
    "=====================\n",
    "\n",
    "上一讲中已经看到如何去定义一个神经网络，计算损失值和更新网络的权重。\n",
    "你现在可能在想下一步。\n",
    "\n",
    "\n",
    "关于数据？\n",
    "----------------\n",
    "\n",
    "一般情况下处理图像、文本、音频和视频数据时，可以使用标准的Python包来加载数据到一个numpy数组中。\n",
    "然后把这个数组转换成 ``torch.*Tensor``。\n",
    "\n",
    "-  图像可以使用 Pillow, OpenCV\n",
    "-  音频可以使用 scipy, librosa\n",
    "-  文本可以使用原始Python和Cython来加载，或者使用 NLTK或\n",
    "   SpaCy 处理\n",
    "\n",
    "特别的，对于图像任务，我们创建了一个包\n",
    "``torchvision``，它包含了处理一些基本图像数据集的方法。这些数据集包括\n",
    "Imagenet, CIFAR10, MNIST 等。除了数据加载以外，``torchvision`` 还包含了图像转换器，\n",
    "``torchvision.datasets`` 和 ``torch.utils.data.DataLoader``。\n",
    "\n",
    "``torchvision``包不仅提供了巨大的便利，也避免了代码的重复。\n",
    "\n",
    "在这个教程中，我们使用CIFAR10数据集，它有如下10个类别\n",
    "：‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,\n",
    "‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是\n",
    "3x32x32大小的，即，3颜色通道，32x32像素。\n",
    "\n",
    "![](https://pytorch.org/tutorials/_images/cifar10.png)\n",
    "\n",
    "\n",
    "训练一个图像分类器\n",
    "----------------------------\n",
    "\n",
    "依次按照下列顺序进行：\n",
    "\n",
    "1. 使用``torchvision``加载和归一化CIFAR10训练集和测试集\n",
    "2. 定义一个卷积神经网络\n",
    "3. 定义损失函数\n",
    "4. 在训练集上训练网络\n",
    "5. 在测试集上测试网络\n",
    "\n",
    "\n",
    "1. 读取和归一化 CIFAR10\n",
    "------------------------------\n",
    "\n",
    "使用``torchvision``可以非常容易地加载CIFAR10。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "torchvision的输出是[0,1]的PILImage图像，我们把它转换为归一化范围为[-1, 1]的张量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们展示一些训练图像。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# 展示图像的函数\n",
    "\n",
    "\n",
    "def imshow(img):\n",
    "    img = img / 2 + 0.5     # unnormalize\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "\n",
    "\n",
    "# 获取随机数据\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "# 展示图像\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# 显示图像标签\n",
    "print(' '.join('%5s' % classes[labels[j]] for j in range(4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "2. 定义一个卷积神经网络\n",
    "-------------------------------\n",
    "从之前的神经网络一节复制神经网络代码，并修改为输入3通道图像。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "net = Net()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "3. 定义损失函数和优化器\n",
    "----------------------------------------\n",
    "\n",
    "我们使用交叉熵作为损失函数，使用带动量的随机梯度下降。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "4. 训练网路\n",
    "--------------------------------\n",
    "有趣的时刻开始了。\n",
    "我们只需在数据迭代器上循环，将数据输入给网络，并优化。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(2):  # 多批次循环\n",
    "\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        # 获取输入\n",
    "        inputs, labels = data\n",
    "\n",
    "        # 梯度置0\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # 正向传播，反向传播，优化\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # 打印状态信息\n",
    "        running_loss += loss.item()\n",
    "        if i % 2000 == 1999:    # 每2000批次打印一次\n",
    "            print('[%d, %5d] loss: %.3f' %\n",
    "                  (epoch + 1, i + 1, running_loss / 2000))\n",
    "            running_loss = 0.0\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "5. 在测试集上测试网络\n",
    "-------------------------------------\n",
    "\n",
    "我们在整个训练集上进行了2次训练，但是我们需要检查网络是否从数据集中学习到有用的东西。\n",
    "通过预测神经网络输出的类别标签与实际情况标签进行对比来进行检测。\n",
    "如果预测正确，我们把该样本添加到正确预测列表。\n",
    "第一步，显示测试集中的图片并熟悉图片内容。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataiter = iter(testloader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "# 显示图片\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们看看神经网络认为以上图片是什么。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = net(images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "输出是10个标签的能量。\n",
    "一个类别的能量越大，神经网络越认为它是这个类别。所以让我们得到最高能量的标签。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, predicted = torch.max(outputs, 1)\n",
    "\n",
    "print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]\n",
    "                              for j in range(4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "结果看来不错。\n",
    "\n",
    "接下来让看看网络在整个测试集上的结果如何。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
    "    100 * correct / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "结果看起来不错，至少比随机选择要好，随机选择的正确率为10%。\n",
    "似乎网络学习到了一些东西。\n",
    "\n",
    "\n",
    "\n",
    "在识别哪一个类的时候好，哪一个不好呢？\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_correct = list(0. for i in range(10))\n",
    "class_total = list(0. for i in range(10))\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        c = (predicted == labels).squeeze()\n",
    "        for i in range(4):\n",
    "            label = labels[i]\n",
    "            class_correct[label] += c[i].item()\n",
    "            class_total[label] += 1\n",
    "\n",
    "\n",
    "for i in range(10):\n",
    "    print('Accuracy of %5s : %2d %%' % (\n",
    "        classes[i], 100 * class_correct[i] / class_total[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下一步?\n",
    "\n",
    "我们如何在GPU上运行神经网络呢？\n",
    "\n",
    "在GPU上训练\n",
    "----------------\n",
    "把一个神经网络移动到GPU上训练就像把一个Tensor转换GPU上一样简单。并且这个操作会递归遍历有所模块，并将其参数和缓冲区转换为CUDA张量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# 确认我们的电脑支持CUDA，然后显示CUDA信息：\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本节的其余部分假定`device`是CUDA设备。\n",
    "\n",
    "然后这些方法将递归遍历所有模块并将模块的参数和缓冲区\n",
    "转换成CUDA张量：\n",
    "\n",
    "\n",
    "```python\n",
    "\n",
    "    net.to(device)\n",
    "```\n",
    "\n",
    "记住：inputs 和 targets 也要转换。\n",
    "\n",
    "```python\n",
    "\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "```\n",
    "为什么我们没注意到GPU的速度提升很多？那是因为网络非常的小。\n",
    "\n",
    "**实践:** \n",
    "尝试增加你的网络的宽度（第一个``nn.Conv2d``的第2个参数，第二个``nn.Conv2d``的第一个参数，它们需要是相同的数字），看看你得到了什么样的加速。\n",
    "\n",
    "**实现的目标**:\n",
    "\n",
    "- 深入了解了PyTorch的张量库和神经网络\n",
    "- 训练了一个小网络来分类图片\n",
    "\n",
    "***译者注：后面我们教程会训练一个真正的网络，使识别率达到90%以上。***\n",
    "\n",
    "多GPU训练\n",
    "-------------------------\n",
    "如果你想使用所有的GPU得到更大的加速，\n",
    "请查看[数据并行处理](5_data_parallel_tutorial.ipynb)。\n",
    "\n",
    "下一步？\n",
    "-------------------\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "-  :doc:`训练神经网络玩电子游戏 </intermediate/reinforcement_q_learning>`\n",
    "-  `在ImageNet上训练最好的ResNet`\n",
    "-  `使用对抗生成网络来训练一个人脸生成器`\n",
    "-  `使用LSTM网络训练一个字符级的语言模型`\n",
    "-  `更多示例`\n",
    "-  `更多教程`\n",
    "-  `在论坛上讨论PyTorch`\n",
    "-  `Slack上与其他用户讨论`\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
